{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DweYe9FcbMK_"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "AVV2e0XKbJeX"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sZfSvVcDo6GQ"
      },
      "source": [
        "# Load text"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "giK0nMbZFnoR"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/load_data/text\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/load_data/text.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/load_data/text.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/load_data/text.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dwlfPb11GH8J"
      },
      "source": [
        "This tutorial demonstrates two ways to load and preprocess text.\n",
        "\n",
        "+ First, you will use Keras utilities and layers. If you are new to TensorFlow, you should start with these.\n",
        "\n",
        "+ Next, you will use lower-level utilities like `tf.data.TextLineDataset` to load text files, and `tf.text` to preprocess the data for finer-grained control."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sa6IKWvADqH7"
      },
      "outputs": [],
      "source": [
        "# Be sure you're using the stable versions of both tf and tf-text, for binary compatibility.\n",
        "!pip install -q -U tensorflow\n",
        "!pip install -q -U tensorflow-text"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "baYFZMW_bJHh"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "import pathlib\n",
        "import re\n",
        "import string\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow.keras import layers\n",
        "from tensorflow.keras import losses\n",
        "from tensorflow.keras import preprocessing\n",
        "from tensorflow.keras import utils\n",
        "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow_text as tf_text"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Az-d_K5_HQ5k"
      },
      "source": [
        "## Example 1: Predict the tag for a Stack Overflow question\n",
        "\n",
        "As a first example, you will download a dataset of programming questions from Stack Overflow. Each question (\"How do I sort a dictionary by value?\") is labeled with exactly one tag (`Python`, `CSharp`, `JavaScript`, or `Java`). Your task is to develop a model that predicts the tag for a question. This is an example of multi-class classification, an important and widely applicable kind of machine learning problem."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tjC3yLa5IjP7"
      },
      "source": [
        "### Download and explore the dataset\n",
        "\n",
        "\n",
        "Next, you will download the dataset, and explore the directory structure."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8ELgzA6SHTuV"
      },
      "outputs": [],
      "source": [
        "data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz'\n",
        "dataset = utils.get_file(\n",
        "    'stack_overflow_16k.tar.gz',\n",
        "    data_url,\n",
        "    untar=True,\n",
        "    cache_dir='stack_overflow',\n",
        "    cache_subdir='')\n",
        "dataset_dir = pathlib.Path(dataset).parent"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jIrPl5fUH2gb"
      },
      "outputs": [],
      "source": [
        "list(dataset_dir.iterdir())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fEoV7YByJoWQ"
      },
      "outputs": [],
      "source": [
        "train_dir = dataset_dir/'train'\n",
        "list(train_dir.iterdir())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3mxAN17MhEh0"
      },
      "source": [
        "The `train/csharp`, `train/java`, `train/python` and `train/javascript` directories contain many text files, each of which is a Stack Overflow question. Print a file and inspect the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Go1vTSGdJu08"
      },
      "outputs": [],
      "source": [
        "sample_file = train_dir/'python/1755.txt'\n",
        "with open(sample_file) as f:\n",
        "  print(f.read())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "deWBTkpJiO7D"
      },
      "source": [
        "### Load the dataset\n",
        "\n",
        "Next, you will load the data off disk and prepare it into a format suitable for training. To do so, you will use the [text_dataset_from_directory](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/text_dataset_from_directory) utility to create a labeled `tf.data.Dataset`. If you're new to [tf.data](https://www.tensorflow.org/guide/data), it's a powerful collection of tools for building input pipelines.\n",
        "\n",
        "The `preprocessing.text_dataset_from_directory` utility expects a directory structure as follows.\n",
        "\n",
        "```\n",
        "train/\n",
        "...csharp/\n",
        "......1.txt\n",
        "......2.txt\n",
        "...java/\n",
        "......1.txt\n",
        "......2.txt\n",
        "...javascript/\n",
        "......1.txt\n",
        "......2.txt\n",
        "...python/\n",
        "......1.txt\n",
        "......2.txt\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dyl6JTAjlbQV"
      },
      "source": [
        "When running a machine learning experiment, it is a best practice to divide your dataset into three splits: [train](https://developers.google.com/machine-learning/glossary#training_set), [validation](https://developers.google.com/machine-learning/glossary#validation_set), and [test](https://developers.google.com/machine-learning/glossary#test-set). The Stack Overflow dataset has already been divided into train and test, but it lacks a validation set. Create a validation set using an 80:20 split of the training data by using the `validation_split` argument below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qqyliMw8N-az"
      },
      "outputs": [],
      "source": [
        "batch_size = 32\n",
        "seed = 42\n",
        "\n",
        "raw_train_ds = preprocessing.text_dataset_from_directory(\n",
        "    train_dir,\n",
        "    batch_size=batch_size,\n",
        "    validation_split=0.2,\n",
        "    subset='training',\n",
        "    seed=seed)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DMI_gPLfloD7"
      },
      "source": [
        "As you can see above, there are 8,000 examples in the training folder, of which you will use 80% (or 6,400) for training. As you will see in a moment, you can train a model by passing a `tf.data.Dataset` directly to `model.fit`. First, iterate over the dataset and print out a few examples, to get a feel for the data.\n",
        "\n",
        "Note: To increase the difficulty of the classification problem, the dataset author replaced occurrences of the words *Python*, *CSharp*, *JavaScript*, or *Java* in the programming question with the word *blank*."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_JMTyZ6Glt_C"
      },
      "outputs": [],
      "source": [
        "for text_batch, label_batch in raw_train_ds.take(1):\n",
        "  for i in range(10):\n",
        "    print(\"Question: \", text_batch.numpy()[i])\n",
        "    print(\"Label:\", label_batch.numpy()[i])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jCZGl4Q5l2sS"
      },
      "source": [
        "The labels are `0`, `1`, `2` or `3`. To see which of these correspond to which string label, you can check the `class_names` property on the dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gIpCS7YjmGkj"
      },
      "outputs": [],
      "source": [
        "for i, label in enumerate(raw_train_ds.class_names):\n",
        "  print(\"Label\", i, \"corresponds to\", label)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oUsdn-37qol9"
      },
      "source": [
        "Next, you will create a validation and a test dataset. You will use the remaining 1,600 questions from the training set for validation.\n",
        "\n",
        "Note: When using the `validation_split` and `subset` arguments, make sure to either specify a random seed, or to pass `shuffle=False`, so that the validation and training splits have no overlap."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x7m6sCWJQuYt"
      },
      "outputs": [],
      "source": [
        "raw_val_ds = preprocessing.text_dataset_from_directory(\n",
        "    train_dir,\n",
        "    batch_size=batch_size,\n",
        "    validation_split=0.2,\n",
        "    subset='validation',\n",
        "    seed=seed)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BXMZc7fMQwKE"
      },
      "outputs": [],
      "source": [
        "test_dir = dataset_dir/'test'\n",
        "raw_test_ds = preprocessing.text_dataset_from_directory(\n",
        "    test_dir, batch_size=batch_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xdt-ATrGRGDL"
      },
      "source": [
        "### Prepare the dataset for training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HSDO9JZYrK-O"
      },
      "source": [
        "Note: The Preprocessing APIs used in this section are experimental in TensorFlow 2.3 and subject to change."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N6fRti45Rlj8"
      },
      "source": [
        "Next, you will standardize, tokenize, and vectorize the data using the `preprocessing.TextVectorization` layer.\n",
        "* Standardization refers to preprocessing the text, typically to remove punctuation or HTML elements to simplify the dataset.\n",
        "\n",
        "* Tokenization refers to splitting strings into tokens (for example, splitting a sentence into individual words by splitting on whitespace).\n",
        "\n",
        "* Vectorization refers to converting tokens into numbers so they can be fed into a neural network.\n",
        "\n",
        "All of these tasks can be accomplished with this layer. You can learn more about each of these in the [API doc](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing/TextVectorization).\n",
        "\n",
        "* The default standardization converts text to lowercase and removes punctuation.\n",
        "\n",
        "* The default tokenizer splits on whitespace.\n",
        "\n",
        "* The default vectorization mode is `int`. This outputs integer indices (one per token). This mode can be used to build models that take word order into account. You can also use other modes, like `binary`, to build bag-of-words models.\n",
        "\n",
        "\n",
        "You will build two models to learn more about these modes. First, you will use `binary` mode to build a bag-of-words model. Next, you will use `int` mode with a 1D ConvNet."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "voaC43rZR0jc"
      },
      "outputs": [],
      "source": [
        "VOCAB_SIZE = 10000\n",
        "\n",
        "binary_vectorize_layer = TextVectorization(\n",
        "    max_tokens=VOCAB_SIZE,\n",
        "    output_mode='binary')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ifDPFxuf2Hfz"
      },
      "source": [
        "For `int` mode, in addition to the maximum vocabulary size, you need to set an explicit maximum sequence length, which will cause the layer to pad or truncate sequences to exactly `output_sequence_length` values."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XWsY01Zl2aRe"
      },
      "outputs": [],
      "source": [
        "MAX_SEQUENCE_LENGTH = 250\n",
        "\n",
        "int_vectorize_layer = TextVectorization(\n",
        "    max_tokens=VOCAB_SIZE,\n",
        "    output_mode='int',\n",
        "    output_sequence_length=MAX_SEQUENCE_LENGTH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ts6h9b5atD-Y"
      },
      "source": [
        "Next, you will call `adapt` to fit the state of the preprocessing layer to the dataset. This will cause the layer to build an index of strings to integers.\n",
        "\n",
        "Note: It's important to only use your training data when calling `adapt` (using the test set would leak information)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yTXsdDEqSf9e"
      },
      "outputs": [],
      "source": [
        "# Make a text-only dataset (without labels), then call adapt\n",
        "train_text = raw_train_ds.map(lambda text, labels: text)\n",
        "binary_vectorize_layer.adapt(train_text)\n",
        "int_vectorize_layer.adapt(train_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XKVO6Jg7Sls0"
      },
      "source": [
        "See the result of using these layers to preprocess data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RngfPyArSsvM"
      },
      "outputs": [],
      "source": [
        "def binary_vectorize_text(text, label):\n",
        "  text = tf.expand_dims(text, -1)\n",
        "  return binary_vectorize_layer(text), label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_1W54wf0LhQ0"
      },
      "outputs": [],
      "source": [
        "def int_vectorize_text(text, label):\n",
        "  text = tf.expand_dims(text, -1)\n",
        "  return int_vectorize_layer(text), label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vi_sElMiSmXe"
      },
      "outputs": [],
      "source": [
        "# Retrieve a batch (of 32 questions and labels) from the dataset\n",
        "text_batch, label_batch = next(iter(raw_train_ds))\n",
        "first_question, first_label = text_batch[0], label_batch[0]\n",
        "print(\"Question\", first_question)\n",
        "print(\"Label\", first_label)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UGukZoYv2v3v"
      },
      "outputs": [],
      "source": [
        "print(\"'binary' vectorized question:\",\n",
        "      binary_vectorize_text(first_question, first_label)[0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Lu07FsIw2yH5"
      },
      "outputs": [],
      "source": [
        "print(\"'int' vectorized question:\",\n",
        "      int_vectorize_text(first_question, first_label)[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wgjeF9PdS7tN"
      },
      "source": [
        "As you can see above, `binary` mode returns an array denoting which tokens exist at least once in the input, while `int` mode replaces each token by an integer, thus preserving their order. You can look up the token (string) that each integer corresponds to by calling `.get_vocabulary()` on the layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WpBnTZilS8wt"
      },
      "outputs": [],
      "source": [
        "print(\"1289 ---> \", int_vectorize_layer.get_vocabulary()[1289])\n",
        "print(\"313 ---> \", int_vectorize_layer.get_vocabulary()[313])\n",
        "print(\"Vocabulary size: {}\".format(len(int_vectorize_layer.get_vocabulary())))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0kHgPE_YwHvp"
      },
      "source": [
        "You are nearly ready to train your model. As a final preprocessing step, you will apply the `TextVectorization` layers you created earlier to the train, validation, and test datasets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "46LeHmnD55wJ"
      },
      "outputs": [],
      "source": [
        "binary_train_ds = raw_train_ds.map(binary_vectorize_text)\n",
        "binary_val_ds = raw_val_ds.map(binary_vectorize_text)\n",
        "binary_test_ds = raw_test_ds.map(binary_vectorize_text)\n",
        "\n",
        "int_train_ds = raw_train_ds.map(int_vectorize_text)\n",
        "int_val_ds = raw_val_ds.map(int_vectorize_text)\n",
        "int_test_ds = raw_test_ds.map(int_vectorize_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NHuAF8hYfP5Z"
      },
      "source": [
        "### Configure the dataset for performance\n",
        "\n",
        "These are two important methods you should use when loading data to make sure that I/O does not become blocking.\n",
        "\n",
        "`.cache()` keeps data in memory after it's loaded off disk. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache, which is more efficient to read than many small files.\n",
        "\n",
        "`.prefetch()` overlaps data preprocessing and model execution while training. \n",
        "\n",
        "You can learn more about both methods, as well as how to cache data to disk in the [data performance guide](https://www.tensorflow.org/guide/data_performance)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PabA9DFIfSz7"
      },
      "outputs": [],
      "source": [
        "AUTOTUNE = tf.data.AUTOTUNE\n",
        "\n",
        "def configure_dataset(dataset):\n",
        "  return dataset.cache().prefetch(buffer_size=AUTOTUNE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J8GcJLvb3JH0"
      },
      "outputs": [],
      "source": [
        "binary_train_ds = configure_dataset(binary_train_ds)\n",
        "binary_val_ds = configure_dataset(binary_val_ds)\n",
        "binary_test_ds = configure_dataset(binary_test_ds)\n",
        "\n",
        "int_train_ds = configure_dataset(int_train_ds)\n",
        "int_val_ds = configure_dataset(int_val_ds)\n",
        "int_test_ds = configure_dataset(int_test_ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NYGb7z_bfpGm"
      },
      "source": [
        "### Train the model\n",
        "\n",
        "It's time to create your neural network. For the `binary` vectorized data, train a simple bag-of-words linear model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2q8iAU-VMzaN"
      },
      "outputs": [],
      "source": [
        "binary_model = tf.keras.Sequential([layers.Dense(4)])\n",
        "binary_model.compile(\n",
        "    loss=losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])\n",
        "history = binary_model.fit(\n",
        "    binary_train_ds, validation_data=binary_val_ds, epochs=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EwidD-SwNIkz"
      },
      "source": [
        "Next, you will use the `int` vectorized data to build a 1D ConvNet."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5ztw2XH_LbVz"
      },
      "outputs": [],
      "source": [
        "def create_model(vocab_size, num_labels):\n",
        "  model = tf.keras.Sequential([\n",
        "      layers.Embedding(vocab_size, 64, mask_zero=True),\n",
        "      layers.Conv1D(64, 5, padding=\"valid\", activation=\"relu\", strides=2),\n",
        "      layers.GlobalMaxPooling1D(),\n",
        "      layers.Dense(num_labels)\n",
        "  ])\n",
        "  return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s9rG1cFRL31Z"
      },
      "outputs": [],
      "source": [
        "# vocab_size is VOCAB_SIZE + 1 since 0 is used additionally for padding.\n",
        "int_model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=4)\n",
        "int_model.compile(\n",
        "    loss=losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])\n",
        "history = int_model.fit(int_train_ds, validation_data=int_val_ds, epochs=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x3J9Eeuv97zE"
      },
      "source": [
        "Compare the two models:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N8ViDXw99v_u"
      },
      "outputs": [],
      "source": [
        "print(\"Linear model on binary vectorized data:\")\n",
        "binary_model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P9BOeoCwborD"
      },
      "outputs": [],
      "source": [
        "print(\"ConvNet model on int vectorized data:\")\n",
        "int_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zYYW9tUdCtTy"
      },
      "source": [
        "Evaluate both models on the test data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5dTc4nZqf7fK"
      },
      "outputs": [],
      "source": [
        "binary_loss, binary_accuracy = binary_model.evaluate(binary_test_ds)\n",
        "int_loss, int_accuracy = int_model.evaluate(int_test_ds)\n",
        "\n",
        "print(\"Binary model accuracy: {:2.2%}\".format(binary_accuracy))\n",
        "print(\"Int model accuracy: {:2.2%}\".format(int_accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F9dhj8Hey9DS"
      },
      "source": [
        "Note: This example dataset represents a rather simple classification problem. More complex datasets and problems bring out subtle but significant differences in preprocessing strategies and model architectures. Be sure to try out different hyperparameters and epochs to compare various approaches."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h9GaXTsIgP-3"
      },
      "source": [
        "### Export the model\n",
        "\n",
        "In the code above, you applied the `TextVectorization` layer to the dataset before feeding text to the model. If you want to make your model capable of processing raw strings (for example, to simplify deploying it), you can include the `TextVectorization` layer inside your model. To do so, you can create a new model using the weights you just trained."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_bRe3KX8gRCX"
      },
      "outputs": [],
      "source": [
        "export_model = tf.keras.Sequential(\n",
        "    [binary_vectorize_layer, binary_model,\n",
        "     layers.Activation('softmax')])\n",
        "\n",
        "export_model.compile(\n",
        "    loss=losses.SparseCategoricalCrossentropy(from_logits=False),\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])\n",
        "\n",
        "# Test it with `raw_test_ds`, which yields raw strings\n",
        "loss, accuracy = export_model.evaluate(raw_test_ds)\n",
        "print(\"Accuracy: {:2.2%}\".format(accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m2eqTVBP4DUN"
      },
      "source": [
        "Now your model can take raw strings as input and predict a score for each label using `model.predict`. Define a function to find the label with the maximum score:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GU53uRXz45iO"
      },
      "outputs": [],
      "source": [
        "def get_string_labels(predicted_scores_batch):\n",
        "  predicted_int_labels = tf.argmax(predicted_scores_batch, axis=1)\n",
        "  predicted_labels = tf.gather(raw_train_ds.class_names, predicted_int_labels)\n",
        "  return predicted_labels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yqnWc7Nn5eou"
      },
      "source": [
        "### Run inference on new data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BOR2MupW1_zS"
      },
      "outputs": [],
      "source": [
        "inputs = [\n",
        "    \"how do I extract keys from a dict into a list?\",  # python\n",
        "    \"debug public static void main(string[] args) {...}\",  # java\n",
        "]\n",
        "predicted_scores = export_model.predict(inputs)\n",
        "predicted_labels = get_string_labels(predicted_scores)\n",
        "for input_text, label in zip(inputs, predicted_labels):\n",
        "  print(\"Question: \", input_text)\n",
        "  print(\"Predicted label: \", label.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0QDVfii_4slI"
      },
      "source": [
        "Including the text preprocessing logic inside your model enables you to export a model for production that simplifies deployment and reduces the potential for [training-serving skew](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serving_skew).\n",
        "\n",
        "There is a performance difference to keep in mind when choosing where to apply your `TextVectorization` layer. Using it outside of your model enables you to do asynchronous CPU processing and buffering of your data when training on GPU. So, if you're training your model on the GPU, you probably want to go with this option to get the best performance while developing your model, then switch to including the `TextVectorization` layer inside your model when you're ready to prepare for deployment.\n",
        "\n",
        "Visit this [tutorial](https://www.tensorflow.org/tutorials/keras/save_and_load) to learn more about saving models."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p4cvuFzavTRy"
      },
      "source": [
        "## Example 2: Predict the author of Iliad translations\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fOlJ22508RIe"
      },
      "source": [
        "The following provides an example of using `tf.data.TextLineDataset` to load examples from text files, and `tf.text` to preprocess the data. In this example, you will use three different English translations of the same work, Homer's Iliad, and train a model to identify the translator given a single line of text."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-pCgKbOSk7kU"
      },
      "source": [
        "### Download and explore the dataset\n",
        "\n",
        "The texts of the three translations are by:\n",
        "\n",
        " - [William Cowper](https://en.wikipedia.org/wiki/William_Cowper) — [text](https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt)\n",
        "\n",
        " - [Edward, Earl of Derby](https://en.wikipedia.org/wiki/Edward_Smith-Stanley,_14th_Earl_of_Derby) — [text](https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt)\n",
        "\n",
        "- [Samuel Butler](https://en.wikipedia.org/wiki/Samuel_Butler_%28novelist%29) — [text](https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt)\n",
        "\n",
        "The text files used in this tutorial have undergone some typical preprocessing tasks, such as removing the document header and footer, line numbers, and chapter titles. Download these lightly munged files locally."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4YlKQthEYlFw"
      },
      "outputs": [],
      "source": [
        "DIRECTORY_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'\n",
        "FILE_NAMES = ['cowper.txt', 'derby.txt', 'butler.txt']\n",
        "\n",
        "for name in FILE_NAMES:\n",
        "  text_dir = utils.get_file(name, origin=DIRECTORY_URL + name)\n",
        "\n",
        "parent_dir = pathlib.Path(text_dir).parent\n",
        "list(parent_dir.iterdir())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M8PHK5J_cXE5"
      },
      "source": [
        "### Load the dataset\n",
        "\n",
        "You will use `TextLineDataset`, which is designed to create a `tf.data.Dataset` from a text file, in which each example is a line of text from the original file. In contrast, `text_dataset_from_directory` treats all contents of a file as a single example. `TextLineDataset` is useful for text data that is primarily line-based (for example, poetry or error logs).\n",
        "\n",
        "Iterate through these files, loading each one into its own dataset. Each example needs to be individually labeled, so use `tf.data.Dataset.map` to apply a labeler function to each one. This will iterate over every example in the dataset, returning `(example, label)` pairs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YIIWIdPXgk7I"
      },
      "outputs": [],
      "source": [
        "def labeler(example, index):\n",
        "  return example, tf.cast(index, tf.int64)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Ajx7AmZnEg3"
      },
      "outputs": [],
      "source": [
        "labeled_data_sets = []\n",
        "\n",
        "for i, file_name in enumerate(FILE_NAMES):\n",
        "  lines_dataset = tf.data.TextLineDataset(str(parent_dir/file_name))\n",
        "  labeled_dataset = lines_dataset.map(lambda ex: labeler(ex, i))\n",
        "  labeled_data_sets.append(labeled_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wPOsVK1e9NGM"
      },
      "source": [
        "Next, you'll combine these labeled datasets into a single dataset, and shuffle it.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6jAeYkTIi9-2"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = 50000\n",
        "BATCH_SIZE = 64\n",
        "VALIDATION_SIZE = 5000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qd544E-Sh63L"
      },
      "outputs": [],
      "source": [
        "all_labeled_data = labeled_data_sets[0]\n",
        "for labeled_dataset in labeled_data_sets[1:]:\n",
        "  all_labeled_data = all_labeled_data.concatenate(labeled_dataset)\n",
        "\n",
        "all_labeled_data = all_labeled_data.shuffle(\n",
        "    BUFFER_SIZE, reshuffle_each_iteration=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r4JEHrJXeG5k"
      },
      "source": [
        "Print out a few examples as before. The dataset hasn't been batched yet, hence each entry in `all_labeled_data` corresponds to one data point:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gywKlN0xh6u5"
      },
      "outputs": [],
      "source": [
        "for text, label in all_labeled_data.take(10):\n",
        "  print(\"Sentence: \", text.numpy())\n",
        "  print(\"Label:\", label.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5rrpU2_sfDh0"
      },
      "source": [
        "### Prepare the dataset for training\n",
        "\n",
        "Instead of using the Keras `TextVectorization` layer to preprocess the text dataset, you will now use the [`tf.text` API](https://www.tensorflow.org/tutorials/tensorflow_text/intro) to standardize and tokenize the data, build a vocabulary, and use `StaticVocabularyTable` to map tokens to integers to feed to the model.\n",
        "\n",
        "While `tf.text` provides various tokenizers, you will use the `UnicodeScriptTokenizer` to tokenize the dataset. Define a function that converts the text to lower case and tokenizes it. You will use `tf.data.Dataset.map` to apply the tokenization to the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v4DpQW-Y12rm"
      },
      "outputs": [],
      "source": [
        "tokenizer = tf_text.UnicodeScriptTokenizer()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pz8xEj0ugu51"
      },
      "outputs": [],
      "source": [
        "def tokenize(text, unused_label):\n",
        "  lower_case = tf_text.case_fold_utf8(text)\n",
        "  return tokenizer.tokenize(lower_case)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vzUrAzOq31QL"
      },
      "outputs": [],
      "source": [
        "tokenized_ds = all_labeled_data.map(tokenize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jx4Q2i8XLV7o"
      },
      "source": [
        "You can iterate over the dataset and print out a few tokenized examples.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g2mkWri7LiGq"
      },
      "outputs": [],
      "source": [
        "# The dataset is not batched yet, so each element is the token list for one line.\n",
        "for tokens in tokenized_ds.take(5):\n",
        "  print(\"Tokens: \", tokens.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JPd4PsskJ_Xt"
      },
      "source": [
        "Next, you will build a vocabulary by sorting tokens by frequency and keeping the top `VOCAB_SIZE` tokens."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YkHtbGnDh6mg"
      },
      "outputs": [],
      "source": [
        "tokenized_ds = configure_dataset(tokenized_ds)\n",
        "\n",
        "vocab_dict = collections.defaultdict(lambda: 0)\n",
        "for toks in tokenized_ds.as_numpy_iterator():\n",
        "  for tok in toks:\n",
        "    vocab_dict[tok] += 1\n",
        "\n",
        "vocab = sorted(vocab_dict.items(), key=lambda x: x[1], reverse=True)\n",
        "vocab = [token for token, count in vocab]\n",
        "vocab = vocab[:VOCAB_SIZE]\n",
        "vocab_size = len(vocab)\n",
        "print(\"Vocab size: \", vocab_size)\n",
        "print(\"First five vocab entries:\", vocab[:5])"
      ]
    },
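    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The counting step above can be sketched in plain Python on a toy corpus. This is a minimal illustration, not part of the original pipeline: the names `toy_lines`, `toy_counts`, `toy_vocab` and `TOY_VOCAB_SIZE` are made up for the example, and whitespace splitting stands in for `UnicodeScriptTokenizer`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import collections\n",
        "\n",
        "# Hypothetical toy corpus; whitespace splitting stands in for the real tokenizer.\n",
        "toy_lines = ['the wrath of achilles', 'the ships of the greeks']\n",
        "TOY_VOCAB_SIZE = 3  # assumed cap, playing the role of VOCAB_SIZE\n",
        "\n",
        "# Count how often each lower-cased token occurs across all lines.\n",
        "toy_counts = collections.Counter(\n",
        "    token for line in toy_lines for token in line.lower().split())\n",
        "\n",
        "# most_common orders tokens by descending count; keep only the top entries.\n",
        "toy_vocab = [token for token, _ in toy_counts.most_common(TOY_VOCAB_SIZE)]\n",
        "print(toy_vocab)"
      ]
    },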
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PyKSsaNAKi17"
      },
      "source": [
        "To convert the tokens into integers, use the `vocab` list to create a `StaticVocabularyTable`. You will map tokens to integers in the range `[2, vocab_size + 2)`. As with the `TextVectorization` layer, `0` is reserved to denote padding and `1` is reserved to denote an out-of-vocabulary (OOV) token."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kCBo2yFHD7y6"
      },
      "outputs": [],
      "source": [
        "keys = vocab\n",
        "values = range(2, len(vocab) + 2)  # reserve 0 for padding, 1 for OOV\n",
        "\n",
        "init = tf.lookup.KeyValueTensorInitializer(\n",
        "    keys, values, key_dtype=tf.string, value_dtype=tf.int64)\n",
        "\n",
        "num_oov_buckets = 1\n",
        "vocab_table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z5F-EiBpOADE"
      },
      "source": [
        "Finally, define a function to standardize, tokenize and vectorize the dataset using the tokenizer and lookup table:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HcIQ7LOTh6eT"
      },
      "outputs": [],
      "source": [
        "def preprocess_text(text, label):\n",
        "  standardized = tf_text.case_fold_utf8(text)\n",
        "  tokenized = tokenizer.tokenize(standardized)\n",
        "  vectorized = vocab_table.lookup(tokenized)\n",
        "  return vectorized, label"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v6S5Qyabi-vo"
      },
      "source": [
        "You can try this on a single example to see the output:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jgxPZaxUuTbk"
      },
      "outputs": [],
      "source": [
        "example_text, example_label = next(iter(all_labeled_data))\n",
        "print(\"Sentence: \", example_text.numpy())\n",
        "vectorized_text, example_label = preprocess_text(example_text, example_label)\n",
        "print(\"Vectorized sentence: \", vectorized_text.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p9qHM0v8k_Mg"
      },
      "source": [
        "Now run the preprocess function on the dataset using `tf.data.Dataset.map`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KmQVsAgJ-RM0"
      },
      "outputs": [],
      "source": [
        "all_encoded_data = all_labeled_data.map(preprocess_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_YZToSXSm0qr"
      },
      "source": [
        "### Split the dataset into training and validation sets\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "itxIJwkrUXgv"
      },
      "source": [
        "The Keras `TextVectorization` layer also batches and pads the vectorized data. Padding is required because the examples inside a batch need to be the same size and shape, but the examples in these datasets are not all the same size: each line of text has a different number of words. `tf.data.Dataset` supports splitting and padded-batching datasets: "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r-rmbijQh6bf"
      },
      "outputs": [],
      "source": [
        "train_data = all_encoded_data.skip(VALIDATION_SIZE).shuffle(BUFFER_SIZE)\n",
        "validation_data = all_encoded_data.take(VALIDATION_SIZE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qTP0IwHBCn0Q"
      },
      "outputs": [],
      "source": [
        "train_data = train_data.padded_batch(BATCH_SIZE)\n",
        "validation_data = validation_data.padded_batch(BATCH_SIZE)"
      ]
    },
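    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the padding step concrete, here is a plain-Python sketch of what padding a batch to a common length looks like. It is only an illustration under stated assumptions: `pad_batch` is a hypothetical helper written for this example, not a TensorFlow API, and it uses `0` as the padding value to match the convention in this tutorial. With its default arguments, `padded_batch` does the equivalent work for you, padding each batch to the length of its longest example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def pad_batch(sequences, pad_value=0):\n",
        "  # Hypothetical helper: right-pad every sequence to the longest length in the batch.\n",
        "  max_len = max(len(seq) for seq in sequences)\n",
        "  return [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]\n",
        "\n",
        "# Three made-up vectorized lines of different lengths.\n",
        "toy_batch = [[5, 9, 2], [7], [3, 4]]\n",
        "padded_toy_batch = pad_batch(toy_batch)\n",
        "print(padded_toy_batch)"
      ]
    },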
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m-wmFq8uW1zS"
      },
      "source": [
        "Now, `validation_data` and `train_data` are no longer collections of `(example, label)` pairs, but collections of batches. Each batch is a pair of (*many examples*, *many labels*) represented as arrays. To illustrate:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kMslWfuwoqpB"
      },
      "outputs": [],
      "source": [
        "sample_text, sample_labels = next(iter(validation_data))\n",
        "print(\"Text batch shape: \", sample_text.shape)\n",
        "print(\"Label batch shape: \", sample_labels.shape)\n",
        "print(\"First text example: \", sample_text[0])\n",
        "print(\"First label example: \", sample_labels[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UI4I6_Sa0vWu"
      },
      "source": [
        "Since `0` is used for padding and `1` for out-of-vocabulary (OOV) tokens, the vocabulary size has increased by two."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u21LlkO8QGRX"
      },
      "outputs": [],
      "source": [
        "vocab_size += 2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h44Ox11OYLP-"
      },
      "source": [
        "Configure the datasets for better performance as before."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BpT0b_7mYRXV"
      },
      "outputs": [],
      "source": [
        "train_data = configure_dataset(train_data)\n",
        "validation_data = configure_dataset(validation_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K8SUhGFNsmRi"
      },
      "source": [
        "### Train the model\n",
        "\n",
        "You can train a model on this dataset as before."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QJgI1pow2YR9"
      },
      "outputs": [],
      "source": [
        "model = create_model(vocab_size=vocab_size, num_labels=3)\n",
        "model.compile(\n",
        "    optimizer='adam',\n",
        "    loss=losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=['accuracy'])\n",
        "history = model.fit(train_data, validation_data=validation_data, epochs=3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KTPCYf_Jh6TH"
      },
      "outputs": [],
      "source": [
        "loss, accuracy = model.evaluate(validation_data)\n",
        "\n",
        "print(\"Loss: \", loss)\n",
        "print(\"Accuracy: {:2.2%}\".format(accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_knIsO-r4pHb"
      },
      "source": [
        "### Export the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FEuMLJA_Xiwo"
      },
      "source": [
        "To make the model capable of taking raw strings as input, you will create a `TextVectorization` layer that performs the same steps as the custom preprocessing function. Since you have already built a vocabulary, you can use `set_vocabulary` instead of `adapt`, which would train a new vocabulary."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_ODkRXbk6aHb"
      },
      "outputs": [],
      "source": [
        "preprocess_layer = TextVectorization(\n",
        "    max_tokens=vocab_size,\n",
        "    standardize=tf_text.case_fold_utf8,\n",
        "    split=tokenizer.tokenize,\n",
        "    output_mode='int',\n",
        "    output_sequence_length=MAX_SEQUENCE_LENGTH)\n",
        "preprocess_layer.set_vocabulary(vocab)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G-Cvd27y4qwt"
      },
      "outputs": [],
      "source": [
        "export_model = tf.keras.Sequential(\n",
        "    [preprocess_layer, model,\n",
        "     layers.Activation('softmax')])\n",
        "\n",
        "export_model.compile(\n",
        "    loss=losses.SparseCategoricalCrossentropy(from_logits=False),\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pyg0B4zsc-UD"
      },
      "outputs": [],
      "source": [
        "# Create a test dataset of raw strings\n",
        "test_ds = all_labeled_data.take(VALIDATION_SIZE).batch(BATCH_SIZE)\n",
        "test_ds = configure_dataset(test_ds)\n",
        "loss, accuracy = export_model.evaluate(test_ds)\n",
        "print(\"Loss: \", loss)\n",
        "print(\"Accuracy: {:2.2%}\".format(accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o6Mm0Y9QYQwE"
      },
      "source": [
        "The loss and accuracy for the model on the encoded validation set and for the exported model on the raw validation set are the same, as expected."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Stk2BP8GE-qo"
      },
      "source": [
        "### Run inference on new data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-w1fQGJPD2Yh"
      },
      "outputs": [],
      "source": [
        "inputs = [\n",
        "    \"Join'd to th' Ionians with their flowing robes,\",  # Label: 1\n",
        "    \"the allies, and his armour flashed about him so that he seemed to all\",  # Label: 2\n",
        "    \"And with loud clangor of his arms he fell.\",  # Label: 0\n",
        "]\n",
        "predicted_scores = export_model.predict(inputs)\n",
        "predicted_labels = tf.argmax(predicted_scores, axis=1)\n",
        "for text, label in zip(inputs, predicted_labels):\n",
        "  print(\"Sentence: \", text)\n",
        "  print(\"Predicted label: \", label.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9eA8TVdnA-3L"
      },
      "source": [
        "## Downloading more datasets using TensorFlow Datasets (TFDS)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2QFSxfZ3Vqsn"
      },
      "source": [
        "You can download many more datasets from [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview). As an example, you will download the [IMDB Large Movie Review dataset](https://www.tensorflow.org/datasets/catalog/imdb_reviews), and use it to train a model for sentiment classification."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NzC65LOaVw0B"
      },
      "outputs": [],
      "source": [
        "train_ds = tfds.load(\n",
        "    'imdb_reviews',\n",
        "    split='train[:80%]',\n",
        "    batch_size=BATCH_SIZE,\n",
        "    shuffle_files=True,\n",
        "    as_supervised=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XKGkgPBkFh0k"
      },
      "outputs": [],
      "source": [
        "val_ds = tfds.load(\n",
        "    'imdb_reviews',\n",
        "    split='train[80%:]',\n",
        "    batch_size=BATCH_SIZE,\n",
        "    shuffle_files=True,\n",
        "    as_supervised=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BQjf3YZAb5Ne"
      },
      "source": [
        "Print a few examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bq1w8MnfWt2C"
      },
      "outputs": [],
      "source": [
        "for review_batch, label_batch in val_ds.take(1):\n",
        "  for i in range(5):\n",
        "    print(\"Review: \", review_batch[i].numpy())\n",
        "    print(\"Label: \", label_batch[i].numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q-lVaukyb75k"
      },
      "source": [
        "You can now preprocess the data and train a model as before. \n",
        "\n",
        "Note: You will use `losses.BinaryCrossentropy` instead of `losses.SparseCategoricalCrossentropy` for your model since this is a binary classification problem."
      ]
    },
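    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the loss named in the note, the sketch below computes binary cross-entropy for a single example from a raw logit in plain Python. The helper name `binary_crossentropy_from_logit` and the sample values are assumptions made for this example. `losses.BinaryCrossentropy(from_logits=True)` applies the same idea to whole batches in a numerically stable way."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def binary_crossentropy_from_logit(logit, label):\n",
        "  # Hypothetical helper: squash the logit to a probability with a sigmoid,\n",
        "  # then take the negative log-likelihood of the true label (0 or 1).\n",
        "  prob = 1.0 / (1.0 + math.exp(-logit))\n",
        "  return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))\n",
        "\n",
        "# The same positive logit is cheap when the true label is 1 and costly when it is 0.\n",
        "print(binary_crossentropy_from_logit(3.0, 1))\n",
        "print(binary_crossentropy_from_logit(3.0, 0))"
      ]
    },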
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ciz2CxAsZw3Z"
      },
      "source": [
        "### Prepare the dataset for training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UzT_t9ihZLH4"
      },
      "outputs": [],
      "source": [
        "vectorize_layer = TextVectorization(\n",
        "    max_tokens=VOCAB_SIZE,\n",
        "    output_mode='int',\n",
        "    output_sequence_length=MAX_SEQUENCE_LENGTH)\n",
        "\n",
        "# Make a text-only dataset (without labels), then call adapt\n",
        "train_text = train_ds.map(lambda text, labels: text)\n",
        "vectorize_layer.adapt(train_text)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zz-Xrd_ZZ4tB"
      },
      "outputs": [],
      "source": [
        "def vectorize_text(text, label):\n",
        "  text = tf.expand_dims(text, -1)\n",
        "  return vectorize_layer(text), label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ycn0Itd6g5aF"
      },
      "outputs": [],
      "source": [
        "train_ds = train_ds.map(vectorize_text)\n",
        "val_ds = val_ds.map(vectorize_text)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jc11jQTlZ5lj"
      },
      "outputs": [],
      "source": [
        "# Configure datasets for performance as before\n",
        "train_ds = configure_dataset(train_ds)\n",
        "val_ds = configure_dataset(val_ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SQzoYkaGZ82Z"
      },
      "source": [
        "### Train the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B9IOTLkyZ-a7"
      },
      "outputs": [],
      "source": [
        "model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=1)\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xLnDs5dhaBAk"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "    loss=losses.BinaryCrossentropy(from_logits=True),\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rq59QpNzaDMa"
      },
      "outputs": [],
      "source": [
        "history = model.fit(train_ds, validation_data=val_ds, epochs=3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gCMWCEtyaEbR"
      },
      "outputs": [],
      "source": [
        "loss, accuracy = model.evaluate(val_ds)\n",
        "\n",
        "print(\"Loss: \", loss)\n",
        "print(\"Accuracy: {:2.2%}\".format(accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jGtqLXVnaaFy"
      },
      "source": [
        "### Export the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yE9WZARZaZr1"
      },
      "outputs": [],
      "source": [
        "export_model = tf.keras.Sequential(\n",
        "    [vectorize_layer, model,\n",
        "     layers.Activation('sigmoid')])\n",
        "\n",
        "export_model.compile(\n",
        "    loss=losses.BinaryCrossentropy(from_logits=False),\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bhF8tDH-afoC"
      },
      "outputs": [],
      "source": [
        "# 0 --> negative review\n",
        "# 1 --> positive review\n",
        "inputs = [\n",
        "    \"This is a fantastic movie.\",\n",
        "    \"This is a bad movie.\",\n",
        "    \"This movie was so bad that it was good.\",\n",
        "    \"I will never say yes to watching this movie.\",\n",
        "]\n",
        "predicted_scores = export_model.predict(inputs)\n",
        "predicted_labels = [int(round(x[0])) for x in predicted_scores]\n",
        "for review, label in zip(inputs, predicted_labels):\n",
        "  print(\"Review: \", review)\n",
        "  print(\"Predicted label: \", label)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q1KSXDFPWiPN"
      },
      "source": [
        "## Conclusion\n",
        "\n",
        "This tutorial demonstrated several ways to load and preprocess text. As a next step, you can explore additional tutorials on the website, or download new datasets from [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "text.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
